{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import json\n",
    "from keras.models import Model\n",
    "from keras.layers import Input\n",
    "from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, BatchNormalization, Multiply\n",
    "from keras import backend as K\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def format_decimal(arr, places=6):\n",
    "    return [round(x * 10**places) / 10**places for x in arr]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "DATA = OrderedDict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### graph 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "random_seed = 10003\n",
    "data_in_shape = (8, 8, 2)\n",
    "\n",
    "input_layer_0 = Input(shape=data_in_shape)\n",
    "branch_0 = Conv2D(4, (3,3), activation='relu', padding='valid', strides=(1,1), data_format='channels_last', use_bias=True)(input_layer_0)\n",
    "\n",
    "input_layer_1 = Input(shape=data_in_shape)\n",
    "branch_1 = Conv2D(4, (3,3), activation='relu', padding='valid', strides=(1,1), data_format='channels_last', use_bias=True)(input_layer_1)\n",
    "\n",
    "output_layer = Multiply()([branch_0, branch_1])\n",
    "model = Model(inputs=[input_layer_0, input_layer_1], outputs=output_layer)\n",
    "\n",
    "data_in = []\n",
    "for i in range(2):\n",
    "    np.random.seed(random_seed + i)\n",
    "    data_in.append(np.expand_dims(2 * np.random.random(data_in_shape) - 1, axis=0))\n",
    "\n",
    "# set weights to random (use seed for reproducibility)\n",
    "weights = []\n",
    "for i, w in enumerate(model.get_weights()):\n",
    "    np.random.seed(random_seed + i)\n",
    "    weights.append(2 * np.random.random(w.shape) - 1)\n",
    "model.set_weights(weights)\n",
    "\n",
    "result = model.predict(data_in)\n",
    "data_out_shape = result[0].shape\n",
    "data_in_formatted = [format_decimal(data_in[i].ravel().tolist()) for i in range(2)]\n",
    "data_out_formatted = format_decimal(result[0].ravel().tolist())\n",
    "\n",
    "DATA['graph_03'] = {\n",
    "    'inputs': [{'data': data_in_formatted[i], 'shape': data_in_shape} for i in range(2)],\n",
    "    'weights': [{'data': format_decimal(w.ravel().tolist()), 'shape': w.shape} for w in weights],\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### export for Keras.js tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "filename = '../../test/data/graph/03.json'\n",
    "if not os.path.exists(os.path.dirname(filename)):\n",
    "    os.makedirs(os.path.dirname(filename))\n",
    "with open(filename, 'w') as f:\n",
    "    json.dump(DATA, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"graph_03\": {\"inputs\": [{\"data\": [0.841248, 0.161158, 0.958983, 0.490417, 0.824656, 0.052908, 0.775934, 0.147982, -0.867877, -0.555668, -0.331903, 0.40502, -0.595569, -0.471818, -0.559532, -0.57215, -0.89974, -0.573032, -0.35398, -0.705059, -0.927461, -0.650224, -0.563291, -0.79357, 0.061125, 0.136481, -0.97837, 0.072996, -0.878323, 0.62697, -0.667599, -0.553128, 0.009751, -0.264958, -0.054374, -0.140509, 0.513801, 0.652987, 0.027122, 0.024481, -0.967907, 0.249529, 0.127448, -0.037277, -0.270137, 0.320024, -0.432672, 0.754105, -0.324482, 0.154028, -0.336313, -0.608767, 0.80611, -0.203209, -0.131729, 0.218894, -0.561637, -0.664737, 0.711306, 0.08824, -0.771977, 0.860109, 0.782578, -0.50916, -0.445579, -0.333019, 0.295793, 0.565551, -0.946729, -0.858205, -0.673907, -0.061506, -0.579362, 0.140314, -0.46189, 0.959434, -0.90153, -0.331571, 0.02769, -0.090527, -0.646693, 0.697831, 0.327324, 0.953899, 0.434714, -0.018754, 0.679171, 0.450538, 0.541374, -0.928426, 0.595931, -0.575864, -0.733921, -0.429094, -0.522865, -0.612979, -0.265133, 0.961394, 0.327196, 0.305523, -0.361009, 0.755892, -0.38218, -0.424697, 0.478128, -0.304299, -0.844554, 0.158311, 0.435603, -0.562215, -0.722012, -0.087758, -0.46262, -0.443516, 0.190205, -0.135997, -0.768256, 0.611786, 0.763558, -0.818832, 0.090036, 0.860763, 0.641142, -0.294615, -0.63098, 0.80676, -0.158879, -0.224822], \"shape\": [8, 8, 2]}, {\"data\": [-0.592446, -0.52773, 0.08531, -0.949905, -0.127145, 0.411354, 0.419711, -0.622227, 0.872841, -0.88662, 0.701828, 0.153318, 0.327738, 0.178798, 0.813647, -0.366809, 0.712692, 0.986871, 0.668377, 0.736488, -0.264375, 0.19263, -0.308472, 0.407634, -0.951991, 0.764595, -0.73637, 0.313222, 0.629502, 0.444145, -0.198403, 0.174327, -0.411546, 0.35668, -0.479344, -0.451058, 0.094248, -0.172221, 0.8534, 0.999072, 0.705663, 0.149181, -0.316913, 0.880756, 0.17336, 0.883104, -0.88683, -0.455842, -0.982796, -0.645087, -0.728562, -0.492119, -0.941125, -0.696325, 0.703916, 0.751858, -0.828058, 0.145984, 0.967902, 0.566607, 0.620443, 0.060608, 0.960336, 0.077866, -0.260331, -0.995759, 0.872716, 0.516793, -0.53123, 0.709423, -0.436639, 0.143448, -0.351875, -0.464221, 0.994688, -0.157409, -0.233078, 0.572034, -0.951472, 0.079706, -0.750931, -0.230296, -0.489441, -0.219935, -0.686203, -0.24279, 0.567421, -0.961362, -0.966846, 0.618548, 0.987486, 0.167981, -0.901755, 0.067314, 0.689805, 0.720812, -0.288329, 0.793488, -0.940847, 0.267783, 0.778441, -0.929597, -0.553821, -0.702763, 0.778919, 0.428746, -0.098998, 0.345026, -0.220246, -0.720604, 0.355463, -0.364463, 0.039299, 0.426763, -0.114002, -0.388025, 0.016499, -0.40977, 0.068633, -0.835767, -0.263438, 0.161918, -0.050737, -0.40455, 0.048353, -0.375367, -0.112814, 0.742089], \"shape\": [8, 8, 2]}], \"weights\": [{\"data\": [0.841248, 0.161158, 0.958983, 0.490417, 0.824656, 0.052908, 0.775934, 0.147982, -0.867877, -0.555668, -0.331903, 0.40502, -0.595569, -0.471818, -0.559532, -0.57215, -0.89974, -0.573032, -0.35398, -0.705059, -0.927461, -0.650224, -0.563291, -0.79357, 0.061125, 0.136481, -0.97837, 0.072996, -0.878323, 0.62697, -0.667599, -0.553128, 0.009751, -0.264958, -0.054374, -0.140509, 0.513801, 0.652987, 0.027122, 0.024481, -0.967907, 0.249529, 0.127448, -0.037277, -0.270137, 0.320024, -0.432672, 0.754105, -0.324482, 0.154028, -0.336313, -0.608767, 0.80611, -0.203209, -0.131729, 0.218894, -0.561637, -0.664737, 0.711306, 0.08824, -0.771977, 0.860109, 0.782578, -0.50916, -0.445579, -0.333019, 0.295793, 0.565551, -0.946729, -0.858205, -0.673907, -0.061506], \"shape\": [3, 3, 2, 4]}, {\"data\": [-0.592446, -0.52773, 0.08531, -0.949905], \"shape\": [4]}, {\"data\": [-0.905266, -0.940244, 0.975308, -0.726779, 0.318662, -0.747347, -0.333653, -0.875921, 0.959097, -0.046026, 0.500018, -0.914037, 0.884544, 0.190366, -0.635631, -0.160835, -0.946491, 0.638743, -0.403733, 0.431154, 0.771673, -0.000895, -0.667179, -0.734761, 0.366933, -0.524729, 0.171283, 0.611351, 0.477013, 0.586021, 0.75193, -0.305236, -0.374709, 0.756282, 0.61714, -0.823425, -0.135917, -0.45961, -0.35282, 0.165981, -0.875342, -0.910735, 0.54216, 0.791704, -0.363715, -0.379062, -0.778289, -0.503017, -0.498858, 0.73821, -0.560404, 0.383332, -0.162873, 0.462676, -0.888228, 0.603298, -0.211376, -0.410015, 0.969717, -0.800772, -0.326232, -0.903871, -0.472227, 0.527646, 0.845871, -0.555119, -0.242145, -0.720775, 0.230819, 0.098654, -0.136847, -0.75402], \"shape\": [3, 3, 2, 4]}, {\"data\": [0.87859, 0.511819, -0.499979, 0.899103], \"shape\": [4]}], \"expected\": {\"data\": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.748307, 0.0, 0.220511, 0.0, 0.0, 0.0, 0.0, 0.654918, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03624, 0.0, 0.862481, 0.0, 0.0, 1.340629, 0.0, 0.0, 0.0, 0.653713, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.039703, 0.0, 0.0, 0.0, 0.0, 1.271317, 5.29952, 0.0, 0.0, 0.0, 0.026066, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.884795, 0.585572, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.020983, 0.0, 0.298598, 0.0, 0.759136, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 8.630027, 0.902179, 3.959464, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.384447, 0.0, 0.0, 0.0, 0.545168, 0.0, 0.0, 0.583308, 0.0, 1.972604, 0.0, 0.618443, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.210449, 0.0, 0.0, 0.0, 4.442462, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.110053, 0.934373, 0.0, 0.0, 0.0, 0.0, 0.082385, 0.0, 0.0, 0.0, 0.0, 0.664974, 0.0, 0.0, 0.186025, 0.0, 0.689336, 1.037516, 0.0], \"shape\": [6, 6, 4]}}}\n"
     ]
    }
   ],
   "source": [
    "print(json.dumps(DATA))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
